{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp dispatch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from __future__ import annotations\n",
    "from fastcore.imports import *\n",
    "from fastcore.foundation import *\n",
    "from fastcore.utils import *\n",
    "\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nbdev.showdoc import *\n",
    "from fastcore.test import *\n",
    "from fastcore.nb_imports import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Type dispatch\n",
    "\n",
    "> Basic single and dual parameter dispatch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def lenient_issubclass(cls, types):\n",
    "    \"If possible return whether `cls` is a subclass of `types`, otherwise return False.\"\n",
    "    if cls is object and types is not object: return False # treat `object` as highest level\n",
    "    try: return isinstance(cls, types) or issubclass(cls, types)\n",
    "    except: return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert not lenient_issubclass(typing.Collection, list)\n",
    "assert lenient_issubclass(list, typing.Collection)\n",
    "assert lenient_issubclass(typing.Collection, object)\n",
    "assert lenient_issubclass(typing.List, typing.Collection)\n",
    "assert not lenient_issubclass(typing.Collection, typing.List)\n",
    "assert not lenient_issubclass(object, typing.Callable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):\n",
    "    \"Return a new list containing all items from the iterable sorted topologically\"\n",
    "    l,res = L(list(iterable)),[]\n",
    "    for _ in range(len(l)):\n",
    "        t = l.reduce(lambda x,y: y if cmp(y,x) else x)\n",
    "        res.append(t), l.remove(t)\n",
    "    return res[::-1] if reverse else res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "td = [3, 1, 2, 5]\n",
    "test_eq(sorted_topologically(td), [1, 2, 3, 5])\n",
    "test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "td = {int:1, numbers.Number:2, numbers.Integral:3}\n",
    "test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "td = [numbers.Integral, tuple, list, int, dict]\n",
    "td = sorted_topologically(td, cmp=lenient_issubclass)\n",
    "assert td.index(int) < td.index(numbers.Integral)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _chk_defaults(f, ann):\n",
    "    pass\n",
    "# Implementation removed until we can figure out how to do this without `inspect` module\n",
    "#     try: # Some callables don't have signatures, so ignore those errors\n",
    "#         params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]\n",
    "#         if any(p.default!=inspect.Parameter.empty for p in params):\n",
    "#             warn(f\"{f.__name__} has default params. These will be ignored.\")\n",
    "#     except ValueError: pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _p2_anno(f):\n",
    "    \"Return the annotations of the first two annotated parameters of `f`, padding with `object`\"\n",
    "    # unannotated parameters (such as `self`) do not appear in the hints, and the return hint is skipped\n",
    "    ann = [typ for name,typ in type_hints(f).items() if name != 'return']\n",
    "    if callable(f): _chk_defaults(f, ann)\n",
    "    # pad to two entries; annotations beyond the second are ignored\n",
    "    ann += [object] * (2 - len(ann))\n",
    "    return ann[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "def _f(a): pass\n",
    "test_eq(_p2_anno(_f), (object,object))\n",
    "def _f(a, b): pass\n",
    "test_eq(_p2_anno(_f), (object,object))\n",
    "def _f(a:None, b)->str: pass\n",
    "test_eq(_p2_anno(_f), (NoneType,object))\n",
    "def _f(a:str, b)->float: pass\n",
    "test_eq(_p2_anno(_f), (str,object))\n",
    "def _f(a:None, b:str)->float: pass\n",
    "test_eq(_p2_anno(_f), (NoneType,str))\n",
    "def _f(a:int, b:int)->float: pass\n",
    "test_eq(_p2_anno(_f), (int,int))\n",
    "def _f(self, a:int, b:int): pass\n",
    "test_eq(_p2_anno(_f), (int,int))\n",
    "def _f(a:int, b:str)->float: pass\n",
    "test_eq(_p2_anno(_f), (int,str))\n",
    "test_eq(_p2_anno(attrgetter('foo')), (object,object))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([object, object], [int, object])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#|hide\n",
    "# Disabled until _chk_defaults fixed\n",
    "# def _f(x:int,y:int=10): pass\n",
    "# test_warns(lambda: _p2_anno(_f))\n",
    "def _f(x:int,y=10): pass\n",
    "_p2_anno(None),_p2_anno(_f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TypeDispatch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Type dispatch, or [Multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based upon the input types it receives.  This is a prominent feature in some programming languages like Julia.  For example, this is a [conceptual example](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia) of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:\n",
    "\n",
    "```julia\n",
    "collide_with(x::Asteroid, y::Asteroid) = ... \n",
    "# deal with asteroid hitting asteroid\n",
    "\n",
    "collide_with(x::Asteroid, y::Spaceship) = ... \n",
    "# deal with asteroid hitting spaceship\n",
    "\n",
    "collide_with(x::Spaceship, y::Asteroid) = ... \n",
    "# deal with spaceship hitting asteroid\n",
    "\n",
    "collide_with(x::Spaceship, y::Spaceship) = ... \n",
    "# deal with spaceship hitting spaceship\n",
    "```\n",
    "\n",
    "Type dispatch can be especially useful in data science, where you might allow different input types (e.g. numpy arrays and pandas dataframes) to a function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.\n",
    "\n",
    "The `TypeDispatch` class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, which ensures that the proper function is called for the given inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _TypeDict:\n",
    "    \"Mapping from types to values, kept sorted topologically with `lenient_issubclass` (subclasses first)\"\n",
    "    def __init__(self):\n",
    "        self.d = {}      # type -> value, re-sorted by `_reset` after every `add`\n",
    "        self.cache = {}  # queried type -> list returned by `all_matches`\n",
    "\n",
    "    def _reset(self):\n",
    "        \"Re-sort `d` topologically with `lenient_issubclass` and drop all cached lookups\"\n",
    "        order = sorted_topologically(self.d, cmp=lenient_issubclass)\n",
    "        self.d = {typ: self.d[typ] for typ in order}\n",
    "        self.cache = {}\n",
    "\n",
    "    def add(self, t, f):\n",
    "        \"Store `f` under type `t`; a tuple or union `t` stores `f` under each of its members\"\n",
    "        members = t if isinstance(t, tuple) else tuple(L(union2tuple(t)))\n",
    "        for typ in members: self.d[typ] = f\n",
    "        self._reset()\n",
    "\n",
    "    def all_matches(self, k):\n",
    "        \"Return the values of every stored type that `k` leniently subclasses, in `d`'s sorted order\"\n",
    "        if k not in self.cache:\n",
    "            self.cache[k] = [v for typ,v in self.d.items() if lenient_issubclass(k, typ)]\n",
    "        return self.cache[k]\n",
    "\n",
    "    def __getitem__(self, k):\n",
    "        \"Return the value of the first type in `d` that `k` matches, or `None` if nothing matches\"\n",
    "        matches = self.all_matches(k)\n",
    "        return matches[0] if matches else None\n",
    "\n",
    "    def __repr__(self): return repr(self.d)\n",
    "    def first(self): return first(self.d.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TypeDispatch:\n",
    "    \"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`\"\n",
    "    def __init__(self, funcs=(), bases=()):\n",
    "        # `funcs` maps the 1st-argument type to a `_TypeDict` of 2nd-argument type -> function\n",
    "        self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))\n",
    "        for o in L(funcs): self.add(o)\n",
    "        # binding context recorded by `__get__` and used by `__call__`\n",
    "        self.inst = None\n",
    "        self.owner = None\n",
    "\n",
    "    def add(self, f):\n",
    "        \"Add function `f`, keyed on the annotations of its first two parameters\"\n",
    "        if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)\n",
    "        else: a0,a1 = _p2_anno(f)\n",
    "        # Register `f` separately under each member of a tuple/union 1st-parameter annotation.\n",
    "        # Each member gets (or reuses) its *own* 2nd-level `_TypeDict`: a dict shared between members\n",
    "        # would leak later registrations for one member (e.g. `(int,float)`) into the others\n",
    "        # (e.g. `numbers.Integral`), and would replace any dict a member already had.\n",
    "        t0s = a0 if isinstance(a0, tuple) else tuple(L(union2tuple(a0)))\n",
    "        for t0 in t0s:\n",
    "            t = self.funcs.d.get(t0)\n",
    "            if t is None:\n",
    "                t = _TypeDict()\n",
    "                self.funcs.add(t0, t)\n",
    "            t.add(a1, f)\n",
    "\n",
    "    def first(self):\n",
    "        \"Get first function in ordered dict of type:func.\"\n",
    "        return self.funcs.first().first()\n",
    "\n",
    "    def returns(self, x):\n",
    "        \"Get the return-type annotation of the function that `x` would be dispatched to.\"\n",
    "        return anno_ret(self[type(x)])\n",
    "\n",
    "    def _attname(self,k): return getattr(k,'__name__',str(k))\n",
    "    def __repr__(self):\n",
    "        # read each 2nd-level dict directly rather than re-dispatching on its key via `self.funcs[k]`\n",
    "        r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, \"__name__\", type(v).__name__)}'\n",
    "             for k,t in self.funcs.d.items() for l,v in t.d.items()]\n",
    "        r = r + [o.__repr__() for o in self.bases]\n",
    "        return '\\n'.join(r)\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        \"Call the function matching the types of the first two positional `args`, or return `args[0]` if none matches\"\n",
    "        ts = L(args).map(type)[:2]\n",
    "        f = self[tuple(ts)]\n",
    "        if not f: return args[0]\n",
    "        if isinstance(f, staticmethod): f = f.__func__\n",
    "        elif self.inst is not None: f = MethodType(f, self.inst)\n",
    "        elif self.owner is not None: f = MethodType(f, self.owner)\n",
    "        return f(*args, **kwargs)\n",
    "\n",
    "    def __get__(self, inst, owner):\n",
    "        # NOTE: the binding is stored on this shared dispatcher itself, not on a bound copy\n",
    "        self.inst = inst\n",
    "        self.owner = owner\n",
    "        return self\n",
    "\n",
    "    def __getitem__(self, k):\n",
    "        \"Find the function for type `k` (or a tuple of up to two types, missing ones default to `object`); falls back to `bases`, else `None`\"\n",
    "        k = L(k)\n",
    "        while len(k)<2: k.append(object)\n",
    "        r = self.funcs.all_matches(k[0])\n",
    "        for t in r:\n",
    "            o = t[k[1]]\n",
    "            if o is not None: return o\n",
    "        for base in self.bases:\n",
    "            res = base[k]\n",
    "            if res is not None: return res\n",
    "        return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate how `TypeDispatch` works, we define a set of functions that accept a variety of input types, specified with different type annotations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f2(x:int, y:float): return x+y              #int and float for 2nd arg\n",
    "def f_nin(x:numbers.Integral)->int:  return x+1 #integral numeric\n",
    "def f_ni2(x:int): return x                      #integer\n",
    "def f_bll(x:bool|list): return x              #bool or list\n",
    "def f_num(x:numbers.Number): return x           #Number (root of numerics)          "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can optionally initialize `TypeDispatch` with a list of functions we want to search.  Printing an instance of `TypeDispatch` will display convenient mapping of types -> functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(bool,object) -> f_bll\n",
       "(int,object) -> f_ni2\n",
       "(Integral,object) -> f_nin\n",
       "(Number,object) -> f_num\n",
       "(list,object) -> f_bll\n",
       "(object,object) -> NoneType"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n",
    "t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that only the first two arguments are used for `TypeDispatch`.  If your function only contains one argument, the second parameter will be shown as `object`.  If you pass `None` into `TypeDispatch`, then this will be displayed as `(object, object) -> NoneType`.\n",
    "\n",
    "`TypeDispatch` is a dictionary-like object, which means that you can retrieve a function by the associated type annotation.  For example, the statement:\n",
    "\n",
    "```py\n",
    "t[float]\n",
    "```\n",
    "will return `f_num`, because it is the matching function whose type annotation, `numbers.Number`, is a super-class of `float`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert issubclass(float, numbers.Number)\n",
    "test_eq(t[float], f_num)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same is true for other types as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(t[np.int32], f_nin)\n",
    "test_eq(t[bool], f_bll)\n",
    "test_eq(t[list], f_bll)\n",
    "test_eq(t[int], f_ni2) # `int` itself is more specific than `numbers.Integral`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you try to get a type that doesn't match, `TypeDispatch` will return `None`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(t[str], None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TypeDispatch.add\" class=\"doc_header\"><code>TypeDispatch.add</code><a href=\"__main__.py#L10\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TypeDispatch.add</code>(**`f`**)\n",
       "\n",
       "Add type `t` and function `f`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TypeDispatch.add)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This method allows you to add an additional function to an existing `TypeDispatch` instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(bool,object) -> f_bll\n",
       "(int,object) -> f_ni2\n",
       "(Integral,object) -> f_nin\n",
       "(Number,object) -> f_num\n",
       "(list,object) -> f_bll\n",
       "(typing.Collection,object) -> f_col\n",
       "(object,object) -> NoneType"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f_col(x:typing.Collection): return x\n",
    "t.add(f_col)\n",
    "test_eq(t[str], f_col)\n",
    "t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you accidentally add the same function more than once things will still work as expected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t.add(f_ni2) \n",
    "test_eq(t[int], f_ni2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, if you add a function that has a type collision that raises an ambiguity, this will automatically resolve to the latest function added:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f_ni3(z:int): return z # collides with f_ni2 with same type annotations\n",
    "t.add(f_ni3) \n",
    "test_eq(t[int], f_ni3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using `bases`:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The argument `bases` can optionally accept a single instance of `TypeDispatch` or a collection (e.g. a tuple or list) of `TypeDispatch` objects.  This can provide functionality similar to multiple inheritance. \n",
    "\n",
    "These are searched for matching functions if no match in your list of functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(str,object) -> f_str\n",
       "(bool,object) -> f_bll\n",
       "(int,object) -> f_ni2\n",
       "(Integral,object) -> f_nin\n",
       "(Number,object) -> f_num\n",
       "(list,object) -> f_bll\n",
       "(object,object) -> NoneType"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f_str(x:str): return x+'1'\n",
    "\n",
    "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n",
    "t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.\n",
    "t2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(t2[int], f_ni2)       # searches `t` b/c not found in `t2`\n",
    "test_eq(t2[np.int32], f_nin)  # searches `t` b/c not found in `t2`\n",
    "test_eq(t2[float], f_num)     # searches `t` b/c not found in `t2`\n",
    "test_eq(t2[bool], f_bll)      # searches `t` b/c not found in `t2`\n",
    "test_eq(t2[str], f_str)       # found in `t2`!\n",
    "test_eq(t2('a'), 'a1')        # found in `t2`!, and uses __call__\n",
    "\n",
    "o = np.int32(1)\n",
    "test_eq(t2(o), 2)             # searches `t` b/c not found in `t2`, and uses __call__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Up To Two Arguments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TypeDispatch` supports up to two arguments when searching for the appropriate function.  The following functions `f1` and `f2` both have two parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(int,float) -> f2\n",
       "(Integral,object) -> f1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f1(x:numbers.Integral, y): return x+1  #Integral is a numeric type\n",
    "def f2(x:int, y:float): return x+y\n",
    "t = TypeDispatch([f1,f2])\n",
    "t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can look up functions from a `TypeDispatch` instance with two parameters like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(t[np.int32], f1)\n",
    "test_eq(t[int,float], f2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Keep in mind that any parameters beyond the first two are ignored, and any collisions will be resolved in favor of the most recently added function.  In the example below, `f1` is ignored in favor of `f2` because their first two parameters have identical type hints:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(str,int) -> f2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f1(a:str, b:int, c:list): return a\n",
    "def f2(a: str, b:int): return b\n",
    "t = TypeDispatch([f1,f2])\n",
    "test_eq(t[str, int], f2)\n",
    "t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Matching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TypeDispatch` matches types with functions according to whether the supplied class is a subclass of (or the same class as) the type annotation(s) of the associated functions.\n",
    "\n",
    "Let's consider an example where we try to retrieve the function corresponding to types of `[np.int32, float]`.\n",
    "\n",
    "In this scenario, `f2` will not be matched. This is because the first type annotation of `f2`, `int`, is not a superclass (or the same class) of `np.int32`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1(x:numbers.Integral, y): return x+1\n",
    "def f2(x:int, y:float): return x+y\n",
    "t = TypeDispatch([f1,f2])\n",
    "\n",
    "assert not issubclass(np.int32, int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead, `f1` is a valid match, as its first argument is annotated with the type `numbers.Integral`, of which `np.int32` is a subclass:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert issubclass(np.int32, numbers.Integral)\n",
    "test_eq(t[np.int32,float], f1) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In `f1`, the 2nd parameter `y` is not annotated, which means `TypeDispatch` will use `f1` whenever the first argument matches `numbers.Integral` and no more specific function (such as `f2` for `(int, float)`) matches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral\n",
    "test_eq(t[int], f1)\n",
    "test_eq(t[int,int], f1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If no match is possible, `None` is returned:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(t[float,float], None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TypeDispatch.__call__\" class=\"doc_header\"><code>TypeDispatch.__call__</code><a href=\"__main__.py#L35\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TypeDispatch.__call__</code>(**\\*`args`**, **\\*\\*`kwargs`**)\n",
       "\n",
       "Call self as a function."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TypeDispatch.__call__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TypeDispatch` is also callable.  When you call an instance of `TypeDispatch`, it will execute the relevant function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f_arr(x:np.ndarray): return x.sum()\n",
    "def f_int(x:np.int32): return x+1\n",
    "t = TypeDispatch([f_arr, f_int])\n",
    "\n",
    "arr = np.array([5,4,3,2,1])\n",
    "test_eq(t(arr), 15) # dispatches to f_arr\n",
    "\n",
    "o = np.int32(1)\n",
    "test_eq(t(o), 2) # dispatches to f_int\n",
    "assert t.first() is not None "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also call an instance of `TypeDispatch` when there are two parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1(x:numbers.Integral, y): return x+1\n",
    "def f2(x:int, y:float): return x+y\n",
    "t = TypeDispatch([f1,f2])\n",
    "\n",
    "test_eq(t(3,2.0), 5)\n",
    "test_eq(t(3,2), 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When no match is found, a `TypeDispatch` instance behaves like an identity function, returning its first argument unchanged.  This default behavior is leveraged by fastai for data transformations to provide a sensible default when a matching function cannot be found."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(t('a'), 'a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TypeDispatch.returns\" class=\"doc_header\"><code>TypeDispatch.returns</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TypeDispatch.returns</code>(**`x`**)\n",
       "\n",
       "Get the return type of annotation of `x`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TypeDispatch.returns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can optionally pass an object to `TypeDispatch.returns` and get the return type annotation back:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1(x:int) -> np.ndarray: return np.array(x)\n",
    "def f2(x:str) -> float: return List\n",
    "def f3(x:float): return List # f3 has no return type annotation\n",
    "\n",
    "t = TypeDispatch([f1, f2, f3])\n",
    "\n",
    "test_eq(t.returns(1), np.ndarray)  # dispatched to f1\n",
    "test_eq(t.returns('Hello'), float) # dispatched to f2\n",
    "test_eq(t.returns(1.0), None)      # dispatched to f3\n",
    "\n",
    "class _Test: pass\n",
    "_test = _Test()\n",
    "test_eq(t.returns(_test), None) # type `_Test` not found, so None returned"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using TypeDispatch With Methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use `TypeDispatch` when defining methods as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_nin(self, x:str|numbers.Integral): return str(x)+'1'\n",
    "def m_bll(self, x:bool): self.foo='a'\n",
    "def m_num(self, x:numbers.Number): return x*2\n",
    "\n",
    "t = TypeDispatch([m_nin,m_num,m_bll])\n",
    "class A: f = t # set class attribute `f` equal to a TypeDispatch instance\n",
    "    \n",
    "a = A()\n",
    "test_eq(a.f(1), '11')  #dispatch to m_nin\n",
    "test_eq(a.f(1.), 2.)   #dispatch to m_num\n",
    "test_is(a.f.inst, a)\n",
    "\n",
    "a.f(False) # this triggers t.m_bll to run, which sets self.foo to 'a'\n",
    "test_eq(a.foo, 'a')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As discussed in `TypeDispatch.__call__`, when there is not a match, `TypeDispatch.__call__` becomes an identity function.  In the below example, a tuple does not match any type annotations so a tuple is returned:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(a.f(()), ()) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We extend the previous example by using `bases` to add an additional method that supports tuples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_tup(self, x:tuple): return x+(1,)\n",
    "t2 = TypeDispatch(m_tup, bases=t)\n",
    "\n",
    "class A2: f = t2\n",
    "a2 = A2()\n",
    "test_eq(a2.f(1), '11')\n",
    "test_eq(a2.f(1.), 2.)\n",
    "test_is(a2.f.inst, a2)\n",
    "a2.f(False)\n",
    "test_eq(a2.foo, 'a')\n",
    "test_eq(a2.f(()), (1,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using TypeDispatch With Class Methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use `TypeDispatch` when defining class methods too:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_nin(cls, x:str|numbers.Integral): return str(x)+'1'\n",
    "def m_bll(cls, x:bool): cls.foo='a'\n",
    "def m_num(cls, x:numbers.Number): return x*2\n",
    "\n",
    "t = TypeDispatch([m_nin,m_num,m_bll])\n",
    "class A: f = t # set class attribute `f` equal to a TypeDispatch\n",
    "\n",
    "test_eq(A.f(1), '11')  #dispatch to m_nin\n",
    "test_eq(A.f(1.), 2.)   #dispatch to m_num\n",
    "test_is(A.f.owner, A)\n",
    "\n",
    "A.f(False) # this triggers t.m_bll to run, which sets A.foo to 'a'\n",
    "test_eq(A.foo, 'a')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## typedispatch Decorator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class DispatchReg:\n",
    "    \"A global registry for `TypeDispatch` objects keyed by function name\"\n",
    "    def __init__(self): self.d = defaultdict(TypeDispatch)\n",
    "\n",
    "    def __call__(self, f):\n",
    "        \"Add `f` to the `TypeDispatch` registered under its `__qualname__` and return that dispatcher\"\n",
    "        # `classmethod`/`staticmethod` objects keep the decorated function on `__func__`\n",
    "        func = f.__func__ if isinstance(f, (classmethod, staticmethod)) else f\n",
    "        disp = self.d[func.__qualname__]\n",
    "        # classmethods are registered unwrapped (binding to the owner happens in `TypeDispatch.__call__`);\n",
    "        # staticmethods and plain functions are registered as given\n",
    "        disp.add(func if isinstance(f, classmethod) else f)\n",
    "        return disp\n",
    "\n",
    "typedispatch = DispatchReg()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@typedispatch\n",
    "def f_td_test(x, y): return f'{x}{y}'\n",
    "@typedispatch\n",
    "def f_td_test(x:numbers.Integral|int, y): return x+1\n",
    "@typedispatch\n",
    "def f_td_test(x:int, y:float): return x+y\n",
    "@typedispatch\n",
    "def f_td_test(x:int, y:int): return x*y\n",
    "\n",
    "test_eq(f_td_test(3,2.0), 5)\n",
    "assert issubclass(int, numbers.Integral)\n",
    "test_eq(f_td_test(3,2), 6)\n",
    "\n",
    "test_eq(f_td_test('a','b'), 'ab')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using typedispatch With other decorators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use `typedispatch` with the `classmethod` and `staticmethod` decorators:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class A:\n",
    "    # `typedispatch` is applied outermost, on top of `classmethod`/`staticmethod`;\n",
    "    # these register under the qualname `A.f_td_test`, separately from the module-level `f_td_test`\n",
    "    @typedispatch\n",
    "    def f_td_test(self, x:numbers.Integral, y): return x+1\n",
    "    @typedispatch\n",
    "    @classmethod\n",
    "    def f_td_test(cls, x:int, y:float): return x+y\n",
    "    @typedispatch\n",
    "    @staticmethod\n",
    "    def f_td_test(x:int, y:int): return x*y\n",
    "    \n",
    "# (int, int) -> staticmethod version: 3*2\n",
    "test_eq(A.f_td_test(3,2), 6)\n",
    "# (int, float) -> classmethod version: 3+2.0\n",
    "test_eq(A.f_td_test(3,2.0), 5)\n",
    "# (int, str) called on an instance -> instance method annotated with numbers.Integral: 3+1\n",
    "test_eq(A().f_td_test(3,'2.0'), 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Casting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we can dispatch on types, let's make it easier to cast objects to a different type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# NOTE(review): `_all_` is presumably read by nbdev to force `cast` into the module's `__all__` - confirm\n",
    "_all_=['cast']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def retain_meta(x, res, as_copy=False):\n",
    "    \"Call `res.set_meta(x)`, if it exists\"\n",
    "    if hasattr(res,'set_meta'): res.set_meta(x, as_copy=as_copy)\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def default_set_meta(self, x, as_copy=False):\n",
    "    \"Copy over `_meta` from `x` to `res`, if it's missing\"\n",
    "    if hasattr(x, '_meta') and not hasattr(self, '_meta'):\n",
    "        meta = x._meta\n",
    "        if as_copy: meta = copy(meta)\n",
    "        self._meta = meta\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "@typedispatch\n",
    "def cast(x, typ):\n",
    "    \"cast `x` to type `typ` (may also change `x` inplace)\"\n",
    "    # Let the target type preprocess `x` if it defines a `_before_cast` hook\n",
    "    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x\n",
    "    if risinstance('ndarray', res): res = res.view(typ)\n",
    "    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)\n",
    "    else:\n",
    "        # Try an in-place class swap first; if that is not allowed, construct a new `typ` from `res`.\n",
    "        # `except Exception` (not a bare `except`) so KeyboardInterrupt/SystemExit still propagate.\n",
    "        try: res.__class__ = typ\n",
    "        except Exception: res = typ(res)\n",
    "    return retain_meta(x, res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This works both for plain Python classes..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mk_class('_T1', 'a')   # mk_class is a fastai utility that constructs a class.\n",
    "class _T2(_T1): pass\n",
    "\n",
    "t = _T1(a=1)\n",
    "# Plain classes allow `__class__` assignment, so `cast` converts `t` in place\n",
    "t2 = cast(t, _T2)        \n",
    "assert t2 is t            # t2 refers to the same object as t\n",
    "assert isinstance(t, _T2) # t also changed in-place\n",
    "assert isinstance(t2, _T2)\n",
    "\n",
    "# The result equals (and has the same type as) a freshly constructed `_T2`\n",
    "test_eq_type(_T2(a=1), t2) \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...as well as for arrays and tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: this rebinds `_T1` (built earlier with `mk_class`) as an `ndarray` subclass\n",
    "class _T1(ndarray): pass\n",
    "\n",
    "t = array([1])\n",
    "# ndarray inputs take the `view` branch of `cast`\n",
    "t2 = cast(t, _T1)\n",
    "test_eq(array([1]), t2)\n",
    "test_eq(_T1, type(t2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To customize casting for other types, define a separate `cast` function with `typedispatch` for your type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def retain_type(new, old=None, typ=None, as_copy=False):\n",
    "    \"Cast `new` to type of `old` or `typ` if it's a superclass\"\n",
    "    # e.g. old is TensorImage, new is Tensor - if not subclass then do nothing\n",
    "    if new is None: return None\n",
    "    assert not (old is None and typ is None)\n",
    "    if typ is None:\n",
    "        # Only proceed when `old` is an instance of `new`'s type\n",
    "        if not isinstance(old, type(new)): return new\n",
    "        typ = old if isinstance(old, type) else type(old)\n",
    "    # Do nothing if the target is `NoneType` or `new` is already an instance of the requested type\n",
    "    if typ == NoneType or isinstance(new, typ): return new\n",
    "    casted = cast(new, typ)\n",
    "    return retain_meta(old, casted, as_copy=as_copy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T(tuple): pass\n",
    "a = _T((1,2))\n",
    "b = tuple((1,2))\n",
    "# With an explicit `typ`, the plain tuple `b` is cast to the subclass `_T`\n",
    "c = retain_type(b, typ=_T)\n",
    "test_eq_type(c, a)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `old` has a `_meta` attribute, it is carried over to the result when casting `new` to the type of `old` (provided the result defines `set_meta`). In the example below, only `_meta` (holding `a`) is kept; `other_attr` is lost because it is not stored in `_meta`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _A():\n",
    "    set_meta = default_set_meta\n",
    "    def __init__(self, t): self.t=t\n",
    "\n",
    "class _B1(_A):\n",
    "    def __init__(self, t, a=1):\n",
    "        super().__init__(t)\n",
    "        self._meta = {'a':a}\n",
    "        self.other_attr = 'Hello' # will not be kept after casting.\n",
    "        \n",
    "x = _B1(1, a=2)\n",
    "b = _A(1)\n",
    "# `x` is an instance of `type(b)`, so `b` is cast to `_B1` and receives `x._meta` through `set_meta`;\n",
    "# `_B1.__init__` is not run by the cast, so `other_attr` is never set on `c`\n",
    "c = retain_type(b, old=x)\n",
    "test_eq(c._meta, {'a': 2})\n",
    "assert not getattr(c, 'other_attr', None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def retain_types(new, old=None, typs=None):\n",
    "    \"Cast each item of `new` to type of matching item in `old` if it's a superclass\"\n",
    "    # Leaves are handed to `retain_type`\n",
    "    if not is_listy(new): return retain_type(new, old, typs)\n",
    "    if typs is None:\n",
    "        # Without `typs`, the container type comes from `old` when it is an instance of `new`'s type\n",
    "        t = type(old) if old is not None and isinstance(old, type(new)) else type(new)\n",
    "    elif isinstance(typs, dict):\n",
    "        # `typs` is `{container_type: item_typs}`, the nested shape returned by `explode_types`\n",
    "        t = first(typs.keys())\n",
    "        typs = typs[t]\n",
    "    else:\n",
    "        t, typs = typs, None\n",
    "    # Recurse over the zipped items of `new`, `old` and `typs`, then rebuild the container as `t`\n",
    "    return t(L(new, old, typs).map_zip(retain_types, cycled=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class T(tuple): pass\n",
    "\n",
    "# Types are taken from the matching nested items of `old`; values come from `new`\n",
    "t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))\n",
    "test_eq_type(t1, 1)\n",
    "test_eq_type(t2, T((1,T((1,1)))))\n",
    "\n",
    "# Same result, with the nested types given explicitly as `typs`\n",
    "t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})\n",
    "test_eq_type(t1, 1)\n",
    "test_eq_type(t2, T((1,T((1,1)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def explode_types(o):\n",
    "    \"Return the type of `o`, potentially in nested dictionaries for things that are listy\"\n",
    "    # Non-listy leaf: just its type\n",
    "    if not is_listy(o): return type(o)\n",
    "    # Listy: map the container type to the exploded types of its items\n",
    "    return {type(o): [explode_types(o_) for o_ in o]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The result has the same nested shape that `retain_types` accepts as `typs`\n",
    "test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_test.ipynb.\n",
      "Converted 01_basics.ipynb.\n",
      "Converted 02_foundation.ipynb.\n",
      "Converted 03_xtras.ipynb.\n",
      "Converted 03a_parallel.ipynb.\n",
      "Converted 03b_net.ipynb.\n",
      "Converted 04_dispatch.ipynb.\n",
      "Converted 05_transform.ipynb.\n",
      "Converted 06_docments.ipynb.\n",
      "Converted 07_meta.ipynb.\n",
      "Converted 08_script.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted parallel_win.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#|hide\n",
    "#|eval: false\n",
    "from nbdev import nbdev_export\n",
    "# Regenerate the library modules from the notebooks (presumably from their `#|export` cells - see the 'Converted ...' output)\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
